In [12]:
import numpy as np
import nolearn
import sklearn.linear_model as lm
import scipy.stats as sps
import math
import pandas as pd

from decimal import Decimal
# The lasagne/nolearn imports below are needed for the neural network cell
# further down; they are re-imported there so the rest of the notebook runs
# even when lasagne is not installed.
# from lasagne import layers, nonlinearities
# from lasagne.updates import nesterov_momentum
# from nolearn.lasagne import NeuralNet
from sklearn.ensemble import (RandomForestRegressor, AdaBoostRegressor,
                              GradientBoostingRegressor, ExtraTreesRegressor,
                              BaggingRegressor)
from sklearn.cross_validation import train_test_split  # sklearn.model_selection in newer versions
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.svm import SVR
from sklearn.externals import joblib  # a standalone `import joblib` in newer versions
from sklearn.utils import resample
from sklearn.preprocessing import LabelBinarizer

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [13]:
import custom_funcs as cf
cf.init_seaborn('white', 'notebook')
from isoelectric_point import isoelectric_points
from molecular_weight import molecular_weights

In [14]:
# Read in the protease inhibitor resistance data; the first 8 columns are
# drug fold-resistance values, the rest are per-position sequence features.
data, drug_cols, feat_cols = cf.read_data('hiv-protease-data.csv', n_data_cols=8)
print(len(data))
# Read in the consensus data
consensus_map = cf.read_consensus('hiv-protease-consensus.fasta')

# Clean the data
data = cf.clean_data(data, feat_cols, consensus_map)

# Drop sequences with ambiguous residues, then rows missing sequence features
data = cf.drop_ambiguous_sequences(data, feat_cols)
data.dropna(inplace=True, subset=feat_cols)
data


1808
Out[14]:
FPV ATV IDV LPV NFV SQV TPV DRV P1 P2 ... P90 P91 P92 P93 P94 P95 P96 P97 P98 P99
SeqID
4432 1.5 NaN 1.0 NaN 2.2 1.1 NaN NaN P Q ... L T Q I G C T L N F
4664 3.1 NaN 8.7 NaN 32.0 16.9 NaN NaN P Q ... M T Q I G C T L N F
5221 NaN NaN 0.8 0.8 1.2 0.7 NaN NaN P Q ... L T Q I G C T L N F
5279 8.3 79.0 16.0 12.0 600.0 1000.0 NaN NaN P Q ... M T Q I G C T L N F
5444 2.7 21.0 24.0 6.1 42.0 132.0 NaN NaN P Q ... M T Q I G C T L N F
5462 2.1 16.0 12.0 22.0 15.0 82.0 NaN NaN P Q ... L T Q I G C T L N F
5464 2.1 NaN 22.2 7.8 24.7 104.8 NaN NaN P Q ... M T Q L G C T L N F
5681 NaN NaN 26.0 25.0 37.0 7.4 NaN NaN P Q ... M T Q L G C T L N F
6024 NaN NaN 8.3 3.0 22.0 3.4 NaN NaN P Q ... M T Q L G C T L N F
6028 NaN NaN 16.0 20.0 37.0 7.9 NaN NaN P Q ... M T Q I G C T L N F
7042 11.0 18.0 28.0 17.0 53.0 62.0 NaN NaN P Q ... M T Q L G C T L N F
7085 0.4 2.0 1.9 0.9 3.7 2.5 NaN NaN P Q ... M T Q I G C T L N F
7103 NaN NaN 0.7 0.7 11.0 0.4 NaN NaN P Q ... L T Q I G C T L N F
7119 1.4 0.9 1.0 0.8 1.6 0.8 NaN NaN P Q ... L T Q I G C T L N F
7412 6.2 NaN 12.0 NaN 10.2 591.5 NaN NaN P Q ... L T Q I G C T L N F
7430 2.8 NaN 48.9 NaN 80.7 42.1 NaN NaN P Q ... M T Q I G C T L N F
7443 2.3 NaN 12.0 NaN 11.0 574.2 NaN NaN P Q ... L T Q L G C T L N F
8188 4.7 29.0 25.0 34.0 28.0 147.0 NaN NaN P Q ... L T Q I G C T L N F
8468 1.4 11.0 17.0 4.4 26.0 20.0 NaN NaN P Q ... M T Q I G C T L N F
8506 5.4 15.0 19.0 7.2 34.0 70.0 NaN NaN P Q ... M T Q I G C T L N F
8626 11.0 15.0 33.0 34.0 56.0 1.5 NaN NaN P Q ... M T Q L G C T L N F
8654 NaN NaN NaN NaN 7.0 1.0 NaN NaN P Q ... L T Q I G C T L N F
8658 NaN NaN NaN NaN 4.0 1.0 NaN NaN P Q ... L T Q I G C T L N F
8660 NaN NaN NaN NaN 37.0 5.0 NaN NaN P Q ... M T Q L G C T L N F
8666 NaN NaN NaN NaN 2.0 1.0 NaN NaN P Q ... L T Q L G C T L N F
8674 NaN NaN NaN NaN 1.0 1.0 NaN NaN P Q ... L T Q I G C T L N F
9431 NaN NaN 2.8 0.8 12.0 0.9 NaN NaN P Q ... L T Q I G C T L N F
9556 NaN NaN 1.2 1.2 3.0 1.0 NaN NaN P Q ... L T Q L G C T L N F
9564 0.4 2.2 0.8 0.5 24.0 0.8 NaN NaN P Q ... L T Q I G C T L N F
9706 NaN NaN 0.3 0.3 0.4 0.4 NaN NaN P Q ... L T Q I G C T L N F
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
235719 0.6 0.6 0.4 0.6 0.7 0.5 0.8 0.5 P Q ... L T Q I G C T L N F
235721 4.7 2.8 3.4 4.4 5.1 5.2 3.2 1.7 P Q ... L T K I G C T L N F
235725 0.5 0.8 0.6 0.6 0.7 0.6 0.9 0.5 P Q ... L T Q I G C T L N F
235729 1.1 1.1 1.0 1.2 1.2 1.1 0.8 1.0 P Q ... L T Q I G C T L N F
235733 23.0 115.0 35.0 102.0 68.0 184.0 NaN NaN P Q ... M K Q L G C T L N F
235739 0.5 0.7 0.8 0.7 0.8 0.8 0.8 0.7 P Q ... L T Q I G C T L N F
257923 1.1 0.8 1.0 1.0 0.9 1.0 1.1 1.6 P Q ... L T Q I G C T L N F
257927 0.4 0.6 0.4 0.4 0.8 0.4 0.5 0.3 P Q ... L T Q I G C T L N F
257929 0.1 0.4 0.3 0.3 0.3 0.4 0.4 0.4 P Q ... L T Q I G C T L N F
257933 0.6 0.9 1.0 0.7 1.0 0.6 0.6 0.6 P Q ... L T Q L G C T L N F
257935 0.8 0.9 0.8 0.8 0.8 0.8 0.8 0.7 P Q ... L T Q I G C T L N F
257939 200.0 204.0 54.0 98.0 32.0 20.0 4.2 117.0 P Q ... M T Q I G C T L N F
257941 0.6 1.0 0.8 0.7 0.9 0.6 0.7 0.5 P Q ... L T Q I G C T L N F
257947 0.6 1.0 1.1 0.9 1.3 1.3 1.7 1.6 P Q ... L T Q I G C T L N F
257957 46.0 200.0 200.0 200.0 89.0 200.0 1.3 33.0 P Q ... L T Q L G C T L N F
257963 0.6 0.7 0.7 0.9 1.1 0.9 1.0 1.0 P Q ... L T Q I G C T L N F
258503 0.2 8.3 3.3 1.4 7.1 1.9 0.6 0.3 P Q ... L T R L G C T L N F
258505 0.7 0.8 0.8 0.8 0.9 0.9 0.9 1.0 P Q ... L T Q I G C T L N F
258507 0.5 0.8 0.8 0.8 1.2 0.7 0.9 0.6 P Q ... L T Q I G C T L N F
258509 2.5 5.0 4.5 2.0 9.1 3.5 2.4 1.3 P Q ... M T Q L G C T L N F
259173 0.7 1.0 1.2 1.1 2.0 1.0 1.0 0.8 P Q ... L T Q L G C T L N F
259175 0.9 0.8 1.0 1.0 0.8 0.8 0.7 0.8 P Q ... L T Q I G C T L N F
259181 2.6 9.3 21.0 6.8 13.0 21.0 1.4 1.5 P Q ... M T Q I G C T L N F
259191 1.1 27.0 30.0 36.0 36.0 200.0 0.6 0.6 P Q ... M T Q L G C T L N F
259195 1.1 1.5 1.6 1.1 1.4 1.3 1.5 0.9 P Q ... L T Q L G C T L N F
259199 0.5 0.8 0.7 0.6 1.1 0.6 0.6 0.6 P Q ... L T Q L G C T L N F
259215 0.6 0.8 1.0 0.7 1.3 0.6 0.7 0.6 P Q ... L T Q L G C T L N F
259227 6.3 6.2 6.3 3.4 20.5 5.3 6.7 2.9 P Q ... M T Q I G C T L N F
259253 0.9 0.8 0.9 0.9 1.5 0.7 0.7 0.6 P Q ... L T Q I G C T L N F
259257 0.8 0.8 0.8 0.6 1.7 0.7 1.1 0.5 P Q ... L T Q L G C T L N F

802 rows × 107 columns
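
For reference, here is a minimal sketch of what a helper like cf.read_data plausibly does, inferred only from how it is called above (drug columns first, then the P1-P99 sequence features, indexed by SeqID). The actual implementation lives in custom_funcs.py and may differ:

In [ ]:
def read_data_sketch(path, n_data_cols):
    """Hypothetical stand-in for cf.read_data; not the author's code."""
    df = pd.read_csv(path, index_col='SeqID')
    drug_cols = df.columns[:n_data_cols]   # e.g. FPV, ATV, ..., DRV
    feat_cols = df.columns[n_data_cols:]   # per-position residues P1..P99
    return df, drug_cols, feat_cols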

Audience Choice

Which drug would you like to see?

  • FPV
  • ATV
  • IDV
  • LPV
  • NFV
  • SQV
  • TPV
  • DRV

In [15]:
# Audience choice: Which drug would you like?
print(drug_cols)

DRUG = 'FPV'


Index(['FPV', 'ATV', 'IDV', 'LPV', 'NFV', 'SQV', 'TPV', 'DRV'], dtype='object')

In [16]:
feat_cols


Out[16]:
Index(['P1', 'P2', 'P3', 'P4', 'P5', 'P6', 'P7', 'P8', 'P9', 'P10', 'P11',
       'P12', 'P13', 'P14', 'P15', 'P16', 'P17', 'P18', 'P19', 'P20', 'P21',
       'P22', 'P23', 'P24', 'P25', 'P26', 'P27', 'P28', 'P29', 'P30', 'P31',
       'P32', 'P33', 'P34', 'P35', 'P36', 'P37', 'P38', 'P39', 'P40', 'P41',
       'P42', 'P43', 'P44', 'P45', 'P46', 'P47', 'P48', 'P49', 'P50', 'P51',
       'P52', 'P53', 'P54', 'P55', 'P56', 'P57', 'P58', 'P59', 'P60', 'P61',
       'P62', 'P63', 'P64', 'P65', 'P66', 'P67', 'P68', 'P69', 'P70', 'P71',
       'P72', 'P73', 'P74', 'P75', 'P76', 'P77', 'P78', 'P79', 'P80', 'P81',
       'P82', 'P83', 'P84', 'P85', 'P86', 'P87', 'P88', 'P89', 'P90', 'P91',
       'P92', 'P93', 'P94', 'P95', 'P96', 'P97', 'P98', 'P99'],
      dtype='object')

In [17]:
# Split data into predictor variables and dependent variables:
# predictors are the sequence features; dependents are the drug resistance values.
def binarize_data(data, feat_cols, DRUG):
    # Bootstrap-resample the rows (sampling with replacement; note that
    # duplicated rows may later land in both the train and test splits).
    data = resample(data)
    X, Y = cf.split_data_xy(data, feat_cols, DRUG)

    # One-hot encode the residue at each position, giving 99 x 20 = 1980
    # binary columns in total (one per position/amino-acid pair).
    lb = LabelBinarizer()
    lb.fit(list('CHIMSVAGLPTRFYWDNEQK'))

    X_binarized = pd.DataFrame()

    for col in X.columns:
        binarized_cols = lb.transform(X[col])

        for i, c in enumerate(lb.classes_):
            X_binarized[col + '_' + c] = binarized_cols[:, i]
    return X_binarized, Y

X_binarized, Y = binarize_data(data, feat_cols, DRUG)


/home/ubuntu/github/hiv-resistance-prediction/custom_funcs.py:189: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  subset.dropna(inplace=True)
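
To make the encoding concrete, here is a tiny, self-contained example (illustrative residues, not data from this set) of what LabelBinarizer does to a single position column:

In [ ]:
from sklearn.preprocessing import LabelBinarizer

lb = LabelBinarizer()
lb.fit(list('CHIMSVAGLPTRFYWDNEQK'))  # lb.classes_ is sorted: A, C, D, ...
onehot = lb.transform(['L', 'V', 'L'])  # one position across three sequences
# onehot is a 3 x 20 binary matrix; each row has a single 1 marking the residue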

In [18]:
# View distribution of drug resistance values
import matplotlib.pyplot as plt
std = (3, 3)  # standard figure size used throughout
fig = cf.plot_Y_histogram(Y, DRUG, figsize=std)

(figure: histogram of FPV drug-resistance values)

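cf.plot_Y_histogram is another custom helper; a minimal equivalent might look like the following (hypothetical sketch, not the author's code):

In [ ]:
def plot_Y_histogram_sketch(Y, drug, figsize=(3, 3)):
    """Hypothetical stand-in for cf.plot_Y_histogram."""
    fig, ax = plt.subplots(figsize=figsize)
    ax.hist(Y.dropna())
    ax.set_xlabel('{0} resistance value'.format(drug))
    ax.set_ylabel('count')
    return fig
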
In [19]:
# Split the data into training and testing sets; keep the 4-tuple around
# so it can be splatted into cf.train_model below.
tts_data = X_train, X_test, Y_train, Y_test = train_test_split(X_binarized, Y)

# Train a bunch of ensemble regressors:
## Random Forest
kwargs = {'n_jobs':-1, 'n_estimators':1000}
rfr, rfr_preds, rfr_mse, rfr_r2 = cf.train_model(*tts_data, model=RandomForestRegressor, modelargs=kwargs)
## Gradient Boosting
kwargs = {'n_estimators':1000}
gbr, gbr_preds, gbr_mse, gbr_r2 = cf.train_model(*tts_data, model=GradientBoostingRegressor, modelargs=kwargs)
## AdaBoost
kwargs = {'n_estimators':1000}
abr, abr_preds, abr_mse, abr_r2 = cf.train_model(*tts_data, model=AdaBoostRegressor, modelargs=kwargs)
## ExtraTrees
etr, etr_preds, etr_mse, etr_r2 = cf.train_model(*tts_data, model=ExtraTreesRegressor)
## Bagging
bgr, bgr_preds, bgr_mse, bgr_r2 = cf.train_model(*tts_data, model=BaggingRegressor)

# Plot the results of regression
rfr_fig = cf.scatterplot_results(rfr_preds, Y_test, rfr_mse, rfr_r2, DRUG, 'Rand. Forest', figsize=std)
# plt.savefig('figures/{0} Random Forest.pdf'.format(DRUG), bbox_inches='tight')
cf.scatterplot_results(gbr_preds, Y_test, gbr_mse, gbr_r2, DRUG, 'Grad. Boost', figsize=std)
cf.scatterplot_results(abr_preds, Y_test, abr_mse, abr_r2, DRUG, 'AdaBoost', figsize=std)
cf.scatterplot_results(etr_preds, Y_test, etr_mse, etr_r2, DRUG, 'ExtraTrees', figsize=std)
cf.scatterplot_results(bgr_preds, Y_test, bgr_mse, bgr_r2, DRUG, 'Bagging', figsize=std)


Out[19]:
(figures: predicted vs. measured scatterplots for Rand. Forest, Grad. Boost, AdaBoost, ExtraTrees, and Bagging)

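The cf.train_model helper is defined in custom_funcs.py and not shown here. A minimal sketch of what such a fit-predict-score helper plausibly looks like, inferred from its call signature and return values above (hypothetical, not the author's code):

In [ ]:
from sklearn.metrics import mean_squared_error, r2_score

def train_model_sketch(X_train, X_test, Y_train, Y_test, model, modelargs=None):
    """Hypothetical stand-in for cf.train_model."""
    reg = model(**(modelargs or {}))  # e.g. RandomForestRegressor(n_estimators=1000)
    reg.fit(X_train, Y_train)
    preds = reg.predict(X_test)
    return reg, preds, mean_squared_error(Y_test, preds), r2_score(Y_test, preds)
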
In [20]:
# Grab the feature importances - that is, how important a particular feature is for predicting drug resistance
cf.barplot_feature_importances(rfr, DRUG, 'Rand. Forest')
cf.barplot_feature_importances(gbr, DRUG, 'Grad. Boost')
cf.barplot_feature_importances(abr, DRUG, 'AdaBoost')
cf.barplot_feature_importances(etr, DRUG, 'ExtraTrees')
# cf.barplot_feature_importances(bgr, DRUG, 'Bagging')  ## BaggingRegressor has no feature_importances_ attribute


Out[20]:
(figures: feature-importance bar plots for Rand. Forest, Grad. Boost, AdaBoost, and ExtraTrees)

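cf.barplot_feature_importances is likewise a custom helper; a minimal sketch of an equivalent plot (hypothetical: the real helper takes only the fitted model and a title, so it may resolve feature names differently):

In [ ]:
def barplot_importances_sketch(model, feat_names, title, k=20):
    """Hypothetical sketch, not the author's helper."""
    imp = pd.Series(model.feature_importances_, index=feat_names)
    ax = imp.nlargest(k).plot(kind='bar', figsize=(6, 3))
    ax.set_title(title)
    ax.set_ylabel('importance')
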
In [30]:
# Extract a table version of feature importance
rfr_fi = cf.extract_mutational_importance(rfr, X_test)
gbr_fi = cf.extract_mutational_importance(gbr, X_test)
abr_fi = cf.extract_mutational_importance(abr, X_test)

# Join data to compare random forest and gradient boosting.
# joined = rfr_fi.set_index(0).join(gbr_fi.set_index(0), lsuffix='r', rsuffix='g')
# sps.spearmanr(joined['1r'], joined['1g'])

rfr_fi


Out[30]:
0 1
0 P10_L 0.435105
1 P84_I 0.076132
2 P47_I 0.059716
3 P54_I 0.045847
4 P33_F 0.037099
5 P88_S 0.034475
6 P32_I 0.020380
7 P46_M 0.017613
8 P84_V 0.017395
9 P90_M 0.013866
10 P32_V 0.013016
11 P47_V 0.009156
12 P50_V 0.008984
13 P90_L 0.008090
14 P46_I 0.007541
15 P84_A 0.007489
16 P33_L 0.006236
17 P20_K 0.005763
18 P54_M 0.003981
19 P89_L 0.003716
20 P58_E 0.003694
21 P10_V 0.003591
22 P13_V 0.003544
23 P37_D 0.003171
24 P50_I 0.003093
25 P71_V 0.003077
26 P63_L 0.003059
27 P35_G 0.003034
28 P64_I 0.003013
29 P63_P 0.002983
... ... ...
1950 P98_L 0.000000
1951 P98_M 0.000000
1952 P98_P 0.000000
1953 P98_Q 0.000000
1954 P98_R 0.000000
1955 P98_S 0.000000
1956 P98_T 0.000000
1957 P98_V 0.000000
1958 P98_W 0.000000
1959 P98_Y 0.000000
1960 P99_A 0.000000
1961 P99_C 0.000000
1962 P99_D 0.000000
1963 P99_E 0.000000
1964 P99_F 0.000000
1965 P99_G 0.000000
1966 P99_H 0.000000
1967 P99_I 0.000000
1968 P99_K 0.000000
1969 P99_L 0.000000
1970 P99_M 0.000000
1971 P99_N 0.000000
1972 P99_P 0.000000
1973 P99_Q 0.000000
1974 P99_R 0.000000
1975 P99_S 0.000000
1976 P99_T 0.000000
1977 P99_V 0.000000
1978 P99_W 0.000000
1979 P99_Y 0.000000

1980 rows × 2 columns
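
Out[30] indicates that cf.extract_mutational_importance returns a two-column frame of (binarized feature name, importance), sorted in descending order. A hypothetical reimplementation consistent with that output (not the author's code):

In [ ]:
def extract_mutational_importance_sketch(model, X):
    """Hypothetical stand-in; the real helper is in custom_funcs.py."""
    fi = pd.DataFrame(list(zip(X.columns, model.feature_importances_)))
    return fi.sort_values(1, ascending=False).reset_index(drop=True)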


In [31]:
# Train several linear models (plus an SVM) for comparison.
brr, brr_preds, brr_mse, brr_r2 = cf.train_model(*tts_data, model=lm.BayesianRidge)
ard, ard_preds, ard_mse, ard_r2 = cf.train_model(*tts_data, model=lm.ARDRegression)
# Note: LogisticRegression is a classifier (it treats each unique resistance
# value as a class), so it is an odd fit here; included only for comparison.
logr, logr_preds, logr_mse, logr_r2 = cf.train_model(*tts_data, model=lm.LogisticRegression)
enr, enr_preds, enr_mse, enr_r2 = cf.train_model(*tts_data, model=lm.ElasticNet)
svr, svr_preds, svr_mse, svr_r2 = cf.train_model(*tts_data, model=SVR)

# Likewise, plot the results
cf.scatterplot_results(brr_preds, Y_test, brr_mse, brr_r2, DRUG, 'Bayesian Ridge', figsize=std)
cf.scatterplot_results(ard_preds, Y_test, ard_mse, ard_r2, DRUG, 'ARD Regression', figsize=std)
cf.scatterplot_results(logr_preds, Y_test, logr_mse, logr_r2, DRUG, 'Logistic Regression', figsize=std)
cf.scatterplot_results(enr_preds, Y_test, enr_mse, enr_r2, DRUG, 'ElasticNet', figsize=std)
cf.scatterplot_results(svr_preds, Y_test, svr_mse, svr_r2, DRUG, 'SVR', figsize=std)


/home/ericmjl/anaconda3/lib/python3.4/site-packages/matplotlib/collections.py:590: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison
  if self._edgecolors == str('face'):
Out[31]:
(figures: predicted vs. measured scatterplots for the linear models and SVR)

In [ ]:
# Let's now try a neural network.
# Neural network 1 specification: a feed-forward ANN with one hidden layer.
# These imports are commented out at the top of the notebook; they require
# lasagne and nolearn to be installed.
from lasagne import layers, nonlinearities
from lasagne.updates import nesterov_momentum
from nolearn.lasagne import NeuralNet

x_train = X_train.astype(np.float32)
y_train = Y_train.astype(np.float32)
x_test = X_test.astype(np.float32)
y_test = Y_test.astype(np.float32)

net1 = NeuralNet(
    layers=[  # input -> hidden (with dropout and tanh nonlinearity) -> output
        ('input', layers.InputLayer),
        ('hidden1', layers.DenseLayer),
        ('dropout1', layers.DropoutLayer),
        #('hidden2', layers.DenseLayer),
        #('dropout2', layers.DropoutLayer),
        ('nonlinear', layers.NonlinearityLayer),
        ('output', layers.DenseLayer),
        ],
    # layer parameters:
    input_shape=(None, x_train.shape[1]),  # 1980 binarized sequence features
    hidden1_num_units=math.ceil(x_train.shape[1] / 2),  # number of units in hidden layer
    hidden1_nonlinearity=nonlinearities.tanh,
    dropout1_p=0.65,
    #hidden2_num_units=math.ceil(x_train.shape[1] / 2),
    #dropout2_p=0.5,
    output_nonlinearity=None,  # output layer uses the identity function
    output_num_units=1,  # a single regression target

    # optimization method:
    update=nesterov_momentum,
    update_learning_rate=0.01,
    update_momentum=0.95,

    regression=True,  # flag to indicate we're dealing with a regression problem
    max_epochs=500,  # train for at most this many epochs
    verbose=1,
    )
net1.fit(x_train.values, y_train.values)


# Neural Network with 1962181 learnable parameters

## Layer information

  #  name         size
---  ---------  ------
  0  input        1980
  1  hidden1       990
  2  dropout1      990
  3  nonlinear     990
  4  output          1

  epoch    train loss    valid loss    train/val  dur
-------  ------------  ------------  -----------  -----
      1       6.57706       3.05727      2.15129  0.04s
      2       2.85664       2.29373      1.24541  0.04s
      3       2.24074       1.50510      1.48877  0.04s
      4       1.50665       1.62378      0.92787  0.04s
      5       2.35926       1.23638      1.90820  0.04s
      6       1.40743       1.20047      1.17240  0.04s
      7       1.08347       0.98490      1.10008  0.04s
      8       1.04599       1.00305      1.04281  0.04s
      9       1.14013       0.90855      1.25489  0.04s
     10       1.16671       1.11425      1.04708  0.04s
     11       1.04511       0.92597      1.12866  0.04s
     12       0.96236       0.95105      1.01189  0.04s
     13       0.98521       0.87606      1.12459  0.04s
     14       0.92985       0.85840      1.08323  0.04s
     15       0.83426       0.96119      0.86794  0.04s
     16       0.75469       0.87354      0.86395  0.04s
     17       0.81038       0.95301      0.85034  0.04s
     18       0.89549       0.83245      1.07573  0.04s
     19       0.74466       0.77892      0.95602  0.04s
     20       0.74198       0.74308      0.99853  0.04s
     21       0.72815       0.90841      0.80156  0.04s
     22       0.66428       0.74047      0.89710  0.04s
     23       0.63248       0.86467      0.73147  0.04s
     24       0.68428       0.81921      0.83529  0.04s
     25       0.66257       0.67275      0.98486  0.04s
     26       0.60111       0.79216      0.75882  0.04s
     27       0.54798       0.64771      0.84602  0.04s
     28       0.65496       0.71365      0.91775  0.04s
     29       0.62162       0.71962      0.86381  0.04s
     30       0.52862       0.62489      0.84594  0.04s
     31       0.57649       0.60407      0.95434  0.04s
     32       0.55267       0.63516      0.87013  0.04s
     33       0.50360       0.57627      0.87389  0.04s
     34       0.55935       0.54618      1.02412  0.04s
     35       0.58662       0.62857      0.93326  0.04s
     36       0.48401       0.54060      0.89533  0.04s
     37       0.51471       0.56738      0.90718  0.04s
     38       0.49885       0.67200      0.74234  0.04s
     39       0.48461       0.53317      0.90894  0.04s
     40       0.52125       0.65211      0.79933  0.04s
     41       0.48052       0.62054      0.77436  0.04s
     42       0.49661       0.61281      0.81038  0.04s
     43       0.45747       0.58250      0.78536  0.04s
     44       0.50726       0.49804      1.01852  0.04s
     45       0.62133       0.56920      1.09160  0.04s
     46       0.49267       0.58219      0.84624  0.04s
     47       0.54785       0.73250      0.74791  0.04s
     48       0.48804       0.68393      0.71358  0.04s
     49       0.55870       0.89196      0.62637  0.04s
     50       0.59605       0.56518      1.05462  0.04s
     51       0.45795       0.55962      0.81832  0.04s
     52       0.44790       0.50735      0.88281  0.04s
     53       0.51230       0.51257      0.99947  0.04s
     54       0.44714       0.49482      0.90363  0.04s
     55       0.53253       0.54573      0.97582  0.04s
     56       0.48661       0.53877      0.90317  0.04s
     57       0.49638       0.47671      1.04127  0.04s
     58       0.50103       0.51592      0.97113  0.04s
     59       0.55287       0.56719      0.97475  0.04s
     60       0.47888       0.57749      0.82923  0.04s
     61       0.56487       0.52559      1.07472  0.04s
     62       0.54847       0.46361      1.18304  0.04s
     63       0.39477       0.48057      0.82145  0.04s
     64       0.45403       0.57276      0.79271  0.04s
     65       0.51732       0.49671      1.04150  0.04s
     66       0.44926       0.50613      0.88764  0.04s
     67       0.48102       0.49807      0.96577  0.04s
     68       0.43978       0.49982      0.87988  0.04s
     69       0.46397       0.48850      0.94979  0.04s
     70       0.44727       0.50115      0.89249  0.04s
     71       0.47505       0.44994      1.05582  0.04s
     72       0.55849       0.47676      1.17142  0.04s
     73       0.51657       0.48419      1.06686  0.04s
     74       0.48601       0.47289      1.02774  0.04s
     75       0.50477       0.46400      1.08786  0.04s
     76       0.47821       0.40282      1.18716  0.04s
     77       0.46718       0.45671      1.02291  0.04s
     78       0.39576       0.42901      0.92248  0.04s
     79       0.50226       0.47054      1.06741  0.04s
     80       0.41820       0.44218      0.94575  0.04s
     81       0.47151       0.44309      1.06414  0.04s
     82       0.53264       0.45267      1.17667  0.04s
     83       0.49641       0.47252      1.05056  0.04s
     84       0.48402       0.49598      0.97588  0.04s
     85       0.44591       0.42504      1.04912  0.04s
     86       0.45757       0.48176      0.94979  0.04s
     87       0.56365       0.48475      1.16275  0.04s
     88       0.45916       0.51017      0.90000  0.04s
     89       0.43360       0.52999      0.81813  0.04s
     90       0.54294       0.66915      0.81139  0.04s
     91       0.50002       0.45305      1.10366  0.04s
     92       0.51030       0.51554      0.98985  0.04s
     93       0.45487       0.51376      0.88536  0.04s
     94       0.51307       0.44948      1.14146  0.04s
     95       0.49293       0.46525      1.05950  0.04s
     96       0.56742       0.46948      1.20861  0.04s
     97       0.49835       0.49128      1.01439  0.04s
     98       0.40542       0.44885      0.90326  0.04s
     99       0.52207       0.45181      1.15550  0.04s
    100       0.55171       0.50814      1.08573  0.04s
    101       0.57680       0.41677      1.38398  0.04s
    102       0.59705       0.42513      1.40439  0.04s
    103       0.64355       0.52865      1.21734  0.04s
    104       0.60059       0.41647      1.44209  0.04s
    105       0.56479       0.42490      1.32922  0.04s
    106       0.57367       0.42532      1.34880  0.04s
    107       0.65901       0.43571      1.51249  0.04s
    108       0.51971       0.46178      1.12544  0.04s
    109       0.46271       0.46406      0.99707  0.04s
    110       0.52505       0.49983      1.05045  0.04s
    111       0.52281       0.49545      1.05522  0.04s
    112       0.54819       0.50193      1.09217  0.04s
    113       0.60299       0.63283      0.95284  0.04s
    114       0.48342       0.49493      0.97675  0.04s
    115       0.52956       0.45375      1.16707  0.04s
    116       0.60116       0.41033      1.46505  0.04s
    117       0.48705       0.41438      1.17537  0.04s
    118       0.58589       0.47310      1.23839  0.04s
    119       0.53067       0.45743      1.16011  0.04s
    120       0.66710       0.41794      1.59616  0.04s
    121       0.48693       0.47146      1.03283  0.04s
    122       0.49610       0.43526      1.13979  0.04s
    123       0.56896       0.47403      1.20025  0.04s
    124       0.58677       0.48600      1.20733  0.04s
    125       0.43398       0.45364      0.95666  0.04s
    126       0.50155       0.47790      1.04948  0.04s
    127       0.54614       0.49100      1.11230  0.04s
    128       0.42272       0.46719      0.90481  0.04s
    129       0.57130       0.49490      1.15437  0.04s
    130       0.53071       0.49729      1.06720  0.04s
    131       0.57320       0.43785      1.30914  0.04s
    132       0.56744       0.42682      1.32946  0.04s
    133       0.49494       0.43336      1.14210  0.04s
    134       0.60395       0.46680      1.29380  0.04s
    135       0.57073       0.51631      1.10541  0.04s
    136       0.53636       0.45569      1.17703  0.04s
    137       0.58437       0.79338      0.73656  0.04s
    138       0.53448       0.47722      1.11998  0.04s
    139       0.60026       0.45703      1.31338  0.04s
    140       0.58861       0.50386      1.16821  0.04s
    141       0.54225       0.44055      1.23084  0.04s
    142       0.53138       0.45402      1.17040  0.04s
    143       0.57582       0.40009      1.43920  0.04s
    144       0.50671       0.48443      1.04598  0.04s
    145       0.54437       0.42133      1.29203  0.04s
    146       0.57629       0.44919      1.28297  0.04s
    147       0.51166       0.49715      1.02919  0.04s
    148       0.53097       0.50302      1.05555  0.04s
    149       0.51169       0.46831      1.09263  0.04s
    150       0.56517       0.42568      1.32770  0.04s
    151       0.51053       0.40597      1.25755  0.04s
    152       0.54635       0.44313      1.23292  0.04s
    153       0.54429       0.53009      1.02679  0.04s
    154       0.49475       0.44269      1.11759  0.04s
    155       0.59370       0.52697      1.12663  0.04s
    156       0.59741       0.46292      1.29051  0.04s
    157       0.55332       0.51335      1.07787  0.04s
    158       0.57663       0.48459      1.18993  0.04s
    159       0.63724       0.48243      1.32090  0.04s
    160       0.56891       0.45680      1.24544  0.04s
    161       0.57386       0.46826      1.22551  0.04s
    162       0.60337       0.42179      1.43049  0.04s
    163       0.59047       0.43092      1.37026  0.04s
    164       0.65812       0.41729      1.57711  0.04s
    165       0.69671       0.56529      1.23247  0.04s
    166       0.72812       0.53379      1.36407  0.04s
    167       0.59821       0.44549      1.34280  0.04s
    168       0.64356       0.46170      1.39390  0.04s
    169       0.64315       0.39355      1.63420  0.04s
    170       0.52577       0.54692      0.96132  0.04s
    171       0.69634       0.57530      1.21041  0.04s
    172       0.76177       0.47020      1.62008  0.04s
    173       0.69670       0.49613      1.40428  0.04s
    174       0.64493       0.51781      1.24550  0.04s
    175       0.63774       0.41458      1.53827  0.04s
    176       0.62472       0.43847      1.42479  0.04s
    177       0.70835       0.54124      1.30876  0.04s
    178       0.60505       0.47857      1.26428  0.04s
    179       0.62573       0.56309      1.11124  0.04s
    180       0.71286       0.46391      1.53662  0.04s
    181       0.66008       0.58553      1.12732  0.04s
    182       0.68076       0.47693      1.42736  0.04s
    183       0.54107       0.50192      1.07798  0.04s
    184       0.59860       0.52311      1.14432  0.04s
    185       0.64328       0.47973      1.34093  0.04s
    186       0.71184       0.49364      1.44204  0.04s
    187       0.56436       0.48338      1.16753  0.04s
    188       0.70068       0.69507      1.00807  0.04s
    189       0.66585       0.47129      1.41284  0.04s
    190       0.63695       0.53963      1.18033  0.04s
    191       0.68865       0.43888      1.56912  0.04s
    192       0.69579       0.48903      1.42279  0.04s
    193       0.51642       0.43859      1.17746  0.04s
    194       0.75854       0.46781      1.62146  0.04s
    195       0.61717       0.43482      1.41939  0.04s
    196       0.73182       0.50876      1.43842  0.04s
    197       0.69082       0.51475      1.34203  0.04s
    198       0.52544       0.44289      1.18638  0.04s
    199       0.61897       0.52898      1.17012  0.04s
    200       0.75014       0.55168      1.35974  0.04s
    201       0.64656       0.51130      1.26454  0.04s
    202       0.85032       0.51515      1.65061  0.04s
    203       0.80250       0.45474      1.76477  0.04s
    204       0.71415       0.40891      1.74650  0.04s
    205       0.63684       0.51190      1.24408  0.04s
    206       0.72472       0.40597      1.78516  0.04s
    207       0.80355       0.45093      1.78198  0.04s
    208       0.67274       0.40744      1.65114  0.04s
    209       0.74143       0.42127      1.75999  0.04s
    210       0.56986       0.46123      1.23552  0.04s
    211       0.71960       0.43605      1.65029  0.04s
    212       0.63482       0.43262      1.46740  0.04s
    213       0.75806       0.45731      1.65764  0.04s
    214       0.60426       0.42925      1.40769  0.04s
    215       0.69331       0.54867      1.26363  0.04s
    216       0.63871       0.48177      1.32574  0.04s
    217       0.65462       0.51447      1.27242  0.04s
    218       0.65379       0.46031      1.42033  0.04s
    219       0.59308       0.58460      1.01451  0.04s
    220       0.73173       0.44843      1.63176  0.04s
    221       0.63843       0.42937      1.48689  0.04s
    222       0.71662       0.50535      1.41808  0.04s
    223       0.59712       0.42067      1.41943  0.04s
    224       0.74133       0.46472      1.59522  0.04s
    225       0.60437       0.56615      1.06751  0.04s
    226       0.81115       0.45948      1.76536  0.04s
    227       0.61049       0.41968      1.45466  0.04s
    228       0.63130       0.44116      1.43101  0.04s
    229       0.74160       0.44706      1.65884  0.04s
    230       0.73151       0.42598      1.71725  0.04s
    231       0.63728       0.50211      1.26920  0.04s
    232       0.66585       0.44714      1.48913  0.04s
    233       0.56648       0.51090      1.10879  0.04s
    234       0.57653       0.43852      1.31471  0.04s
    235       0.60373       0.41846      1.44273  0.04s
    236       0.58157       0.48655      1.19531  0.04s
    237       0.68234       0.43975      1.55165  0.04s
    238       0.58890       0.47885      1.22983  0.04s
    239       0.64680       0.48737      1.32713  0.04s
    240       0.62333       0.46746      1.33344  0.04s
    241       0.71032       0.45833      1.54980  0.04s
    242       0.77229       0.45314      1.70428  0.04s
    243       0.69993       0.50020      1.39931  0.04s
    244       0.63189       0.43335      1.45814  0.04s
    245       0.70853       0.46604      1.52031  0.04s
    246       0.61515       0.41826      1.47073  0.04s
    247       0.60737       0.47317      1.28361  0.04s
    248       0.61652       0.45077      1.36770  0.04s
    249       0.63409       0.42554      1.49007  0.04s
    250       0.73797       0.54287      1.35940  0.04s
    251       0.61718       0.47331      1.30397  0.04s
    252       0.63179       0.45475      1.38931  0.04s
    253       0.74023       0.48589      1.52343  0.04s
    254       0.60676       0.49803      1.21831  0.04s
    255       0.62202       0.49872      1.24723  0.04s
    256       0.67443       0.38777      1.73924  0.04s
    257       0.58900       0.46207      1.27472  0.04s
    258       0.61583       0.44608      1.38054  0.04s
    259       0.56354       0.44263      1.27315  0.04s
    260       0.68244       0.42018      1.62418  0.04s
    261       0.62805       0.39501      1.58999  0.04s

In [ ]:
# And now let's look at how good the neural net's predictions are.
nn1_preds = net1.predict(x_test.values)  # use .values, matching fit() above
nn1_mse = float(mean_squared_error(nn1_preds, y_test))
# Note: despite the variable name, this is Pearson's r, not r-squared.
nn1_r2 = float(sps.pearsonr(nn1_preds, y_test.reshape(y_test.shape[0], 1))[0][0])

cf.scatterplot_results(nn1_preds, y_test, nn1_mse, nn1_r2, DRUG, 'Neural Net', figsize=std)
# plt.savefig('figures/{0} Neural Net.pdf'.format(DRUG), bbox_inches='tight')

In [15]:
# Save models to disk. (Older joblib writes each large numpy array inside a
# model to its own companion .npy file, hence the long file list below.)

# Neural net
joblib.dump(net1, 'models/{0} nnet1.pkl'.format(DRUG))

# Random Forest
joblib.dump(rfr, 'models/{0} rfr.pkl'.format(DRUG))

# Gradient Boost
joblib.dump(gbr, 'models/{0} gbr.pkl'.format(DRUG))

# ExtraTrees
joblib.dump(etr, 'models/{0} etr.pkl'.format(DRUG))


Out[15]:
['models/FPV etr.pkl',
 'models/FPV etr.pkl_01.npy',
 'models/FPV etr.pkl_02.npy',
 'models/FPV etr.pkl_03.npy',
 'models/FPV etr.pkl_04.npy',
 'models/FPV etr.pkl_05.npy',
 'models/FPV etr.pkl_06.npy',
 'models/FPV etr.pkl_07.npy',
 'models/FPV etr.pkl_08.npy',
 'models/FPV etr.pkl_09.npy',
 'models/FPV etr.pkl_10.npy',
 'models/FPV etr.pkl_11.npy',
 'models/FPV etr.pkl_12.npy',
 'models/FPV etr.pkl_13.npy',
 'models/FPV etr.pkl_14.npy',
 'models/FPV etr.pkl_15.npy',
 'models/FPV etr.pkl_16.npy',
 'models/FPV etr.pkl_17.npy',
 'models/FPV etr.pkl_18.npy',
 'models/FPV etr.pkl_19.npy',
 'models/FPV etr.pkl_20.npy',
 'models/FPV etr.pkl_21.npy',
 'models/FPV etr.pkl_22.npy',
 'models/FPV etr.pkl_23.npy',
 'models/FPV etr.pkl_24.npy',
 'models/FPV etr.pkl_25.npy',
 'models/FPV etr.pkl_26.npy',
 'models/FPV etr.pkl_27.npy',
 'models/FPV etr.pkl_28.npy',
 'models/FPV etr.pkl_29.npy',
 'models/FPV etr.pkl_30.npy']
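
The saved models can be restored later with joblib.load; a hypothetical usage example (the companion .npy files must sit alongside the .pkl):

In [ ]:
# Reload a pickled model; the path format matches the dump calls above.
rfr_loaded = joblib.load('models/{0} rfr.pkl'.format(DRUG))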
